{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "PyTorch on Cloud TPUs: MultiCore Training AlexNet on Fashion MNIST",
      "provenance": [],
      "collapsed_sections": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "foxdQ7m3RsOk",
        "colab_type": "text"
      },
      "source": [
        "##PyTorch on Cloud TPUs: MultiCore Training AlexNet on Fashion MNIST \n",
        "\n",
        "This notebook will show you how to train [AlexNet](https://arxiv.org/abs/1404.5997) on the [Fashion MNIST dataset](https://github.com/zalandoresearch/fashion-mnist) using a Cloud TPU and all eight of its cores. It's a follow-up to [this notebook](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/single-core-alexnet-fashion-mnist.ipynb), which trains the same network on the same dataset the using a single Cloud TPU core. This will show you how to train your own networks on a Cloud TPU and highlight the difference between single and multicore training.\n",
        "\n",
        "This notebook is part of a series of tutorials on using PyTorch on Cloud TPUs. PyTorch can use Cloud TPU cores as devices with the PyTorch/XLA package. For more on PyTorch/XLA see its [Github](https://github.com/pytorch/xla) or its [documentation](http://pytorch.org/xla/). We also have a [\"Getting Started\"](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb) Colab notebook. Additional Colab notebooks, like this one, are available on the PyTorch/XLA Github linked above."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_UYwdM3qRjhS",
        "colab_type": "text"
      },
      "source": [
        "### Installing PyTorch/XLA\n",
        "\n",
        "Run the following cell (or copy it into your own notebook!) to install PyTorch, Torchvision, and PyTorch/XLA. It will take a couple minutes to run."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "imKiCg3wReXn",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Installs PyTorch, PyTorch/XLA, and Torchvision\n",
        "# Copy this cell into your own notebooks to use PyTorch on Cloud TPUs \n",
        "# Warning: this may take a couple minutes to run\n",
        "\n",
        "import os\n",
        "assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'\n",
        "\n",
        "VERSION = \"20200325\"  #@param [\"1.5\" , \"20200325\", \"nightly\"]\n",
        "!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n",
        "!python pytorch-xla-env-setup.py --version $VERSION"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kLe_HkWzUCHP",
        "colab_type": "text"
      },
      "source": [
        "### Dataset & Network\n",
        "\n",
        "In this notebook we'll train AlexNet on the Fashion MNIST dataset. Both are provided by the [Torchvision package](https://pytorch.org/docs/stable/torchvision/index.html).\n",
        "\n",
        "The [previous notebook](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/single-core-alexnet-fashion-mnist.ipynb) trained this combination on a single TPU core and took time to visualize and describe the dataset. To avoid redundancy, please refer to it to learn more about the Fashion MNIST dataset. "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hzXONhL6Wa-_",
        "colab_type": "text"
      },
      "source": [
        "### Using Multiple Cloud TPU Cores\n",
        "\n",
        "Working with multiple Cloud TPU cores is different than training on a single Cloud TPU core. With a single Cloud TPU core we simply acquired the device and ran the operations using it directly. To use multiple Cloud TPU cores we must use other processes, one per Cloud TPU core. This indirection and multiplicity makes multicore training a little more complex than training on a single core, but it's necessary to maximize performance.\n",
        "\n",
        "The following cell shows how to launch one process per Cloud TPU core.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eSiwkKVLW9zO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "import torch_xla\n",
        "import torch_xla.core.xla_model as xm\n",
        "import torch_xla.distributed.xla_multiprocessing as xmp\n",
        "\n",
        "# \"Map function\": acquires a corresponding Cloud TPU core, creates a tensor on it,\n",
        "# and prints its core\n",
        "def simple_map_fn(index, flags):\n",
        "  # Sets a common random seed - both for initialization and ensuring graph is the same\n",
        "  torch.manual_seed(1234)\n",
        "\n",
        "  # Acquires the (unique) Cloud TPU core corresponding to this process's index\n",
        "  device = xm.xla_device()  \n",
        "\n",
        "  # Creates a tensor on this process's device\n",
        "  t = torch.randn((2, 2), device=device)\n",
        "\n",
        "  print(\"Process\", index ,\"is using\", xm.xla_real_devices([str(device)])[0])\n",
        "\n",
        "  # Barrier to prevent master from exiting before workers connect.\n",
        "  xm.rendezvous('init')\n",
        "\n",
        "# Spawns eight of the map functions, one for each of the eight cores on\n",
        "# the Cloud TPU\n",
        "flags = {}\n",
        "# Note: Colab only supports start_method='fork'\n",
        "xmp.spawn(simple_map_fn, args=(flags,), nprocs=8, start_method='fork')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iq8D9sGwu_MT",
        "colab_type": "text"
      },
      "source": [
        "Let's start looking at the above cell with the call to `spawn(),` which is documented [here](http://pytorch.org/xla/#torch_xla.distributed.xla_multiprocessing.spawn). `spawn()` takes a function (the \"map function\"), a tuple of arguments (the placeholder `flags` dict), the number of processes to create, and whether to create these new processes by \"forking\" or \"spawning.\" While spawning new processes is generally recommended, Colab only supports forking.\n",
        "\n",
        "`spawn()` will create eight processes, one for each Cloud TPU core, and call `simple_map_fn()` -- the map function -- on each process. The inputs to `simple_map_fn()` are an index (zero through seven) and the placeholder `flags.` When the proccesses acquire their device they actually acquire their corresponding Cloud TPU core automatically, as the above cell demonstrates.\n",
        "\n",
        "\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_Zi_7MF4-nBs",
        "colab_type": "text"
      },
      "source": [
        "### An Aside on Context\n",
        "\n",
        "How did each process in the above cell know to acquire its own Cloud TPU core?\n",
        "\n",
        "The answer is context. Accelerators, like Cloud TPUs, manage their operations using an implicit stateful context. In the cell above, the `spawn()` function creates a multiprocessing context and gives it to each new, forked process, allowing them to coordinate. We'll see another example of this coordination below.\n",
        "\n",
        "Two warnings before we continue! First, you can't mix single process and multiprocess contexts when forking! Practically, this means that all our Cloud TPU-related calls will be done in processes created by spawn."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ej0qnTFC-p6m",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Don't mix these!\n",
        "# Only one type of context per Colab!\n",
        "# Warning: uncommenting the below and running this cell will cause a runtime error!\n",
        "\n",
        "#device = xm.xla_device()  # Requires a single process context\n",
        "\n",
        "#xmp.spawn(simple_map_fn, args=(flags,), nprocs=8, start_method='fork')  # Requires a multiprocess context"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MtXxG6whFR90",
        "colab_type": "text"
      },
      "source": [
        "The second warning is: each process should perform the same Cloud TPU computations!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HveR4CZtFaOW",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Don't perform different computations on different processes!\n",
        "# Warning: uncommenting the below and running this cell will likely hang your Colab!\n",
        "# def simple_map_fn(index, flags):\n",
        "#   torch.manual_seed(1234)\n",
        "#   device = xm.xla_device()  \n",
        "\n",
        "#   if xm.is_master_ordinal():\n",
        "#     t = torch.randn((2, 2), device=device)  # Divergent Cloud TPU computation!\n",
        "\n",
        "\n",
        "# xmp.spawn(simple_map_fn, args=(flags,), nprocs=8, start_method='fork')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "alwue4JCF3Le",
        "colab_type": "text"
      },
      "source": [
        "Performing the same Cloud TPU computations lets the context coordinate the processes correctly. Again, we'll see this better below.\n",
        "\n",
        "Note, however, you CAN perform different CPU computations in each process, as the next cell demonstrates."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RHS6iAUzGWwt",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Common Cloud TPU computation but different CPU computation is OK\n",
        "def simple_map_fn(index, flags):\n",
        "  torch.manual_seed(1234)\n",
        "  device = xm.xla_device()  \n",
        "\n",
        "  t = torch.randn((2, 2), device=device)  # Common Cloud TPU computation\n",
        "  out = str(t)  # Each process uses the XLA tensors the same way\n",
        "\n",
        "  if xm.is_master_ordinal():  # Divergent CPU-only computation (no XLA tensors beyond this point!)\n",
        "    print(out)\n",
        "\n",
        "  # Barrier to prevent master from exiting before workers connect.\n",
        "  xm.rendezvous('init')\n",
        "\n",
        "\n",
        "xmp.spawn(simple_map_fn, args=(flags,), nprocs=8, start_method='fork')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "21WNWICABAo1",
        "colab_type": "text"
      },
      "source": [
        "### Multicore Training\n",
        "\n",
        "The following cell defines a map function that trains AlexNet on the Fashion MNIST dataset on all eight available Cloud TPU cores. The function is long and contained in a single cell, so it includes lengthy comments. It does the following:\n",
        "\n",
        "- **Setup**: every process sets the same random seed and acquires the device assigned to it (via the accelerator context, see above)\n",
        "- **Dataloading**: each process acquires its own copy of the dataset, but their sampling from it is coordinated to not overlap.\n",
        "- **Network creation**: each process has its own copy of the network, but these copies are identical since each process's random seed is the same.\n",
        "- **Training** and **Evaluation**: Training and evaluation occur as usual but use a ParallelLoader.\n",
        "\n",
        "Aside from a couple different classes, like the DistributedSampler and the ParallelLoader, the big difference between single core and multicore training is behind the scenes. The `step()` function now not only propagates gradients, but uses the Cloud TPU context to synchronize gradient updates across each processes' copy of the network. This ensures that each processes' network copy stays \"in sync\" (they are all identical). "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pqx-VgEizPiF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torchvision\n",
        "from torchvision import datasets\n",
        "import torchvision.transforms as transforms\n",
        "import torch_xla.distributed.parallel_loader as pl\n",
        "import time\n",
        "\n",
        "def map_fn(index, flags):\n",
        "  ## Setup \n",
        "\n",
        "  # Sets a common random seed - both for initialization and ensuring graph is the same\n",
        "  torch.manual_seed(flags['seed'])\n",
        "\n",
        "  # Acquires the (unique) Cloud TPU core corresponding to this process's index\n",
        "  device = xm.xla_device()  \n",
        "\n",
        "\n",
        "  ## Dataloader construction\n",
        "\n",
        "  # Creates the transform for the raw Torchvision data\n",
        "  # See https://pytorch.org/docs/stable/torchvision/models.html for normalization\n",
        "  # Pre-trained TorchVision models expect RGB (3 x H x W) images\n",
        "  # H and W should be >= 224\n",
        "  # Loaded into [0, 1] and normalized as follows:\n",
        "  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
        "                                  std=[0.229, 0.224, 0.225])\n",
        "  to_rgb = transforms.Lambda(lambda image: image.convert('RGB'))\n",
        "  resize = transforms.Resize((224, 224))\n",
        "  my_transform = transforms.Compose([resize, to_rgb, transforms.ToTensor(), normalize])\n",
        "\n",
        "  # Downloads train and test datasets\n",
        "  # Note: master goes first and downloads the dataset only once (xm.rendezvous)\n",
        "  #   all the other workers wait for the master to be done downloading.\n",
        "\n",
        "  if not xm.is_master_ordinal():\n",
        "    xm.rendezvous('download_only_once')\n",
        "\n",
        "  train_dataset = datasets.FashionMNIST(\n",
        "    \"/tmp/fashionmnist\",\n",
        "    train=True,\n",
        "    download=True,\n",
        "    transform=my_transform)\n",
        "\n",
        "  test_dataset = datasets.FashionMNIST(\n",
        "    \"/tmp/fashionmnist\",\n",
        "    train=False,\n",
        "    download=True,\n",
        "    transform=my_transform)\n",
        "  \n",
        "  if xm.is_master_ordinal():\n",
        "    xm.rendezvous('download_only_once')\n",
        "  \n",
        "  # Creates the (distributed) train sampler, which let this process only access\n",
        "  # its portion of the training dataset.\n",
        "  train_sampler = torch.utils.data.distributed.DistributedSampler(\n",
        "    train_dataset,\n",
        "    num_replicas=xm.xrt_world_size(),\n",
        "    rank=xm.get_ordinal(),\n",
        "    shuffle=True)\n",
        "  \n",
        "  test_sampler = torch.utils.data.distributed.DistributedSampler(\n",
        "    test_dataset,\n",
        "    num_replicas=xm.xrt_world_size(),\n",
        "    rank=xm.get_ordinal(),\n",
        "    shuffle=False)\n",
        "  \n",
        "  # Creates dataloaders, which load data in batches\n",
        "  # Note: test loader is not shuffled or sampled\n",
        "  train_loader = torch.utils.data.DataLoader(\n",
        "      train_dataset,\n",
        "      batch_size=flags['batch_size'],\n",
        "      sampler=train_sampler,\n",
        "      num_workers=flags['num_workers'],\n",
        "      drop_last=True)\n",
        "\n",
        "  test_loader = torch.utils.data.DataLoader(\n",
        "      test_dataset,\n",
        "      batch_size=flags['batch_size'],\n",
        "      sampler=test_sampler,\n",
        "      shuffle=False,\n",
        "      num_workers=flags['num_workers'],\n",
        "      drop_last=True)\n",
        "  \n",
        "\n",
        "  ## Network, optimizer, and loss function creation\n",
        "\n",
        "  # Creates AlexNet for 10 classes\n",
        "  # Note: each process has its own identical copy of the model\n",
        "  #  Even though each model is created independently, they're also\n",
        "  #  created in the same way.\n",
        "  net = torchvision.models.alexnet(num_classes=10).to(device).train()\n",
        "\n",
        "  loss_fn = torch.nn.CrossEntropyLoss()\n",
        "  optimizer = torch.optim.Adam(net.parameters())\n",
        "\n",
        "\n",
        "  ## Trains\n",
        "  train_start = time.time()\n",
        "  for epoch in range(flags['num_epochs']):\n",
        "    para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)\n",
        "    for batch_num, batch in enumerate(para_train_loader):\n",
        "      data, targets = batch \n",
        "\n",
        "      # Acquires the network's best guesses at each class\n",
        "      output = net(data)\n",
        "\n",
        "      # Computes loss\n",
        "      loss = loss_fn(output, targets)\n",
        "\n",
        "      # Updates model\n",
        "      optimizer.zero_grad()\n",
        "      loss.backward()\n",
        "\n",
        "      # Note: optimizer_step uses the implicit Cloud TPU context to\n",
        "      #  coordinate and synchronize gradient updates across processes.\n",
        "      #  This means that each process's network has the same weights after\n",
        "      #  this is called.\n",
        "      # Warning: this coordination requires the actions performed in each \n",
        "      #  process are the same. In more technical terms, the graph that\n",
        "      #  PyTorch/XLA generates must be the same across processes. \n",
        "      xm.optimizer_step(optimizer)  # Note: barrier=True not needed when using ParallelLoader \n",
        "\n",
        "  elapsed_train_time = time.time() - train_start\n",
        "  print(\"Process\", index, \"finished training. Train time was:\", elapsed_train_time) \n",
        "\n",
        "\n",
        "  ## Evaluation\n",
        "  # Sets net to eval and no grad context \n",
        "  net.eval()\n",
        "  eval_start = time.time()\n",
        "  with torch.no_grad():\n",
        "    num_correct = 0\n",
        "    total_guesses = 0\n",
        "\n",
        "    para_train_loader = pl.ParallelLoader(test_loader, [device]).per_device_loader(device)\n",
        "    for batch_num, batch in enumerate(para_train_loader):\n",
        "      data, targets = batch\n",
        "\n",
        "      # Acquires the network's best guesses at each class\n",
        "      output = net(data)\n",
        "      best_guesses = torch.argmax(output, 1)\n",
        "\n",
        "      # Updates running statistics\n",
        "      num_correct += torch.eq(targets, best_guesses).sum().item()\n",
        "      total_guesses += flags['batch_size']\n",
        "  \n",
        "  elapsed_eval_time = time.time() - eval_start\n",
        "  print(\"Process\", index, \"finished evaluation. Evaluation time was:\", elapsed_eval_time)\n",
        "  print(\"Process\", index, \"guessed\", num_correct, \"of\", total_guesses, \"correctly for\", num_correct/total_guesses * 100, \"% accuracy.\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8rDe9KaS1mzz",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Configures training (and evaluation) parameters\n",
        "flags['batch_size'] = 32\n",
        "flags['num_workers'] = 8\n",
        "flags['num_epochs'] = 1\n",
        "flags['seed'] = 1234\n",
        "\n",
        "xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gkhC6gt1N1c8",
        "colab_type": "text"
      },
      "source": [
        "The network should take about 30 seconds to train and about 10 seconds to evaluate on each process. Using an entire Cloud TPU is, as expected, dramatically faster than training and evaluating on a single Cloud TPU core.\n"
      ]
    },
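    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The per-process totals printed by the evaluation step follow from how the `DistributedSampler` splits each dataset across the eight processes. The next cell is a minimal pure-Python sketch of that split, not the sampler's actual code. It assumes a strided split with padding, which is how `DistributedSampler` divides indices when `shuffle=False` (with `shuffle=True` the indices are permuted first, but the counts are the same). `shard_indices` is a hypothetical helper introduced only for this illustration, and the dataset sizes (60,000 training and 10,000 test images) are those of Fashion MNIST.\n",
        "\n",
        "Because `drop_last=True` discards the final partial batch, each process should use slightly fewer samples than its shard contains. The sketch runs on the CPU only and does not touch the Cloud TPU."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import math\n",
        "\n",
        "def shard_indices(dataset_len, num_replicas, rank):\n",
        "  # Hypothetical helper (not a PyTorch API): strided, padded split of indices\n",
        "  indices = list(range(dataset_len))\n",
        "  total_size = math.ceil(dataset_len / num_replicas) * num_replicas\n",
        "  # Pads by repeating the first indices so every process gets the same count\n",
        "  indices += indices[:total_size - dataset_len]\n",
        "  return indices[rank:total_size:num_replicas]\n",
        "\n",
        "num_replicas = 8  # One process per Cloud TPU core\n",
        "batch_size = 32   # Per-process batch size, as in flags['batch_size']\n",
        "\n",
        "for name, dataset_len in [('train', 60000), ('test', 10000)]:\n",
        "  shards = [shard_indices(dataset_len, num_replicas, r) for r in range(num_replicas)]\n",
        "\n",
        "  # For these sizes no padding is needed, so the shards are disjoint\n",
        "  # and together cover the whole dataset\n",
        "  assert sorted(i for s in shards for i in s) == list(range(dataset_len))\n",
        "\n",
        "  # drop_last=True discards the final partial batch\n",
        "  full_batches = len(shards[0]) // batch_size\n",
        "  print(name, 'samples per process:', len(shards[0]),\n",
        "        'full batches:', full_batches,\n",
        "        'samples used:', full_batches * batch_size)"
      ],
      "execution_count": 0,
      "outputs": []
    },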
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WSp0EOq8Pg2b",
        "colab_type": "text"
      },
      "source": [
        "##What's Next?\n",
        "\n",
        "This notebook broke down training AlexNet on the Fashion MNIST dataset using an entire Cloud TPU. A [previous notebook](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/single-core-alexnet-fashion-mnist.ipynb) showed how to train AlexNet on Fashion MNIST using only a single Cloud TPU core, and can be a helpful point of comparison. \n",
        "\n",
        "In particular, this notebook showed us how to:\n",
        "\n",
        "- Define a \"map function\" that runs in parallel on one process per Cloud TPU core. \n",
        "- Run the map function using `spawn`.\n",
        "- Understand the Cloud TPU context, its benefits, like automatic cross-process coordination, and its limits, like needing each process to perform the same Cloud TPU operations.\n",
        "- Load and sample the datasets.\n",
        "- Train and evaluate the network.\n",
        "\n",
        "Additional notebooks demonstrating how to run PyTorch on Cloud TPUs can be found [here](https://github.com/pytorch/xla). While Colab provides a free Cloud TPU, training is even faster on [Google Cloud Platform](https://github.com/pytorch/xla#Cloud), especially when using multiple Cloud TPUs in a Cloud TPU pod. Scaling from a single Cloud TPU, like in this notebook, to many Cloud TPUs in a pod is easy, too. You use the same code as this notebook and just spawn more processes."
      ]
    }
  ]
}
